
[KDA] sm90 GVA enhance #64

Open

sjmshsh wants to merge 19 commits into inclusionAI:main from sjmshsh:feature/kda_gva_support

Conversation


@sjmshsh sjmshsh commented May 6, 2026

feat(kda, sm90): add KDA GVA forward support

Summary

Add Grouped Value Attention (GVA) forward support to the KDA Hopper / SM90
prefill path: Q/K share num_qk_heads, while V, g, β, O and the recurrent
state are sized by num_v_heads.

Constraints

  • num_q_heads == num_k_heads
  • num_v_heads >= num_qk_heads
  • num_v_heads % num_qk_heads == 0
  • head_dim == 128 (unchanged)

heads_per_group = num_v_heads / num_qk_heads. Each value head is mapped to
exactly one shared Q/K head:

heads_per_group = num_v_heads / num_qk_heads;
qk_head_idx     = v_head_idx / heads_per_group;

When num_v_heads == num_qk_heads this degenerates back to standard MHA.
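
For a concrete shape (illustrative only, mirroring the "GVA 4x" rows in the benchmarks below): num_qk_heads = 16 and num_v_heads = 64 give heads_per_group = 4, so value heads 0–3 read Q/K head 0, value heads 4–7 read Q/K head 1, and so on. A minimal standalone check of that arithmetic, not code from this PR:

#include <cassert>

int main() {
  const int num_qk_heads    = 16;                          // shared Q/K heads
  const int num_v_heads     = 64;                          // value/output heads ("GVA 4x")
  const int heads_per_group = num_v_heads / num_qk_heads;  // = 4

  for (int v_head_idx = 0; v_head_idx < num_v_heads; ++v_head_idx) {
    const int qk_head_idx = v_head_idx / heads_per_group;
    assert(qk_head_idx >= 0 && qk_head_idx < num_qk_heads);
  }
  assert(13 / heads_per_group == 3);   // value head 13 reads Q/K head 3
  assert(63 / heads_per_group == 15);  // last value head reads the last Q/K head
  return 0;
}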

What changed

Across the full Python → C++ → kernel stack:

  • Python wrapper (cula/kda/hopper_fused_fwd.py)
    • Accept independent num_qk_heads and num_v_heads.
    • Per-tensor shape validation: q,k = [B,T,H,D], v,g = [B,T,HV,D],
      beta = [B,T,HV], initial_state = [N,HV,D,D].
    • Fix: respect output_final_state=False (the wrapper previously returned
      the kernel-allocated state tensor even when the final state was not
      requested).
  • C++ entry (csrc/api/kda_sm90.cu)
    • Split num_heads → num_qk_heads / num_v_heads.
    • Allocate output [packed, HV, D] and output_state [N, HV, D, D].
    • Validate α, β, input_state shapes against num_v_heads.
    • Defensive checks: num_qk_heads > 0 and num_v_heads > 0 are validated
      before the % == 0 check, so a degenerate input never triggers
      division-by-zero UB (see the first sketch after this list).
  • Kernel skeleton (kernel_kda_fwd.hpp, prefill_kernel*.hpp/.cuh,
    kda_fwd_sm90.cu, kda_fwd_sm90_safe_gate.cu)
    • VarlenProblemShape carries both num_qk_heads and num_v_heads.
    • Two strides per token row: qk_tok_stride = num_qk_heads * head_size
      for Q/K; v_tok_stride = num_v_heads * head_size for V/O/α.
    • All explicit template instantiations and signatures updated.
  • Tile scheduler (csrc/kda/sm90/kernel/tile_scheduler.hpp)
    • WorkDesc now exposes q_head_idx() / k_head_idx() returning
      qk_head_idx, while v_head_idx() / o_head_idx() return head_idx.
    • grid.x = num_seqs * num_v_heads; one program per (seq, v_head).
    • heads_per_group is computed once on the host in
      to_underlying_arguments and stored in Params, so the device path
      avoids recomputing num_v_heads / num_qk_heads per CTA (see the
      index-derivation sketch after this list).
  • TMA load/store + mainloop
    • load_tma.hpp: Q and K are sliced over num_qk_heads; V/α/O are sliced
      over num_v_heads. The K-vs-V branch uses a constexpr selector, since
      LoadKind is a static template parameter, so the head-count selection
      collapses at compile time (see the last sketch after this list).
    • store_tma.hpp: O written in the V/O head space.
    • mainloop_kda_fwd.hpp: QK GEMM batched over num_qk_heads; KV / α /
      output / state buffers all batched over num_v_heads.
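
Three small sketches follow. All names in them are hypothetical stand-ins; none of this is the PR's actual code. First, the check ordering in the C++ entry: validating positivity before taking the modulus means a degenerate head count can never reach % 0 (TORCH_CHECK is the standard libtorch error macro).

#include <torch/torch.h>

void check_head_counts(int64_t num_qk_heads, int64_t num_v_heads) {
  TORCH_CHECK(num_qk_heads > 0, "num_qk_heads must be positive, got ", num_qk_heads);
  TORCH_CHECK(num_v_heads > 0, "num_v_heads must be positive, got ", num_v_heads);
  // Only safe to evaluate once both operands are known to be positive:
  TORCH_CHECK(num_v_heads % num_qk_heads == 0,
              "num_v_heads (", num_v_heads, ") must be a multiple of num_qk_heads (",
              num_qk_heads, ")");
}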
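
Second, an index-derivation sketch: how one program could recover its (seq, v_head, qk_head) indices and the two token-row strides from the flat grid described above. The real WorkDesc/Params live in tile_scheduler.hpp; the struct and function names here are invented for illustration.

struct ParamsSketch {
  int num_seqs;
  int num_qk_heads;
  int num_v_heads;
  int heads_per_group;  // num_v_heads / num_qk_heads, computed once on the host
  int head_size;        // 128
};

struct WorkDescSketch {
  int seq_idx;
  int qk_head_idx;  // what q_head_idx() / k_head_idx() return
  int v_head_idx;   // what v_head_idx() / o_head_idx() return
};

// In the kernel this would be derived from blockIdx.x with
// grid.x = num_seqs * num_v_heads, i.e. one program per (seq, v_head).
WorkDescSketch decode_work(const ParamsSketch& p, int flat_idx) {
  WorkDescSketch w;
  w.seq_idx     = flat_idx / p.num_v_heads;
  w.v_head_idx  = flat_idx % p.num_v_heads;
  w.qk_head_idx = w.v_head_idx / p.heads_per_group;  // heads_per_group comes from Params
  return w;
}

// Two row strides per token: Q/K rows are narrower than V/O/alpha rows.
int qk_tok_stride(const ParamsSketch& p) { return p.num_qk_heads * p.head_size; }
int v_tok_stride (const ParamsSketch& p) { return p.num_v_heads  * p.head_size; }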
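
Third, the compile-time head-count selection mentioned for load_tma.hpp: because the load kind is a template parameter, the branch collapses at compile time. Again, the enum and function names are hypothetical.

enum class LoadKindSketch { Q, K, V, Alpha, O };

template <LoadKindSketch Kind>
int heads_for_load(int num_qk_heads, int num_v_heads) {
  if constexpr (Kind == LoadKindSketch::Q || Kind == LoadKindSketch::K) {
    return num_qk_heads;   // Q and K are sliced over the shared Q/K heads
  } else {
    return num_v_heads;    // V / alpha / O live in the value-head space
  }
}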

Tests

pytest tests/test_kda_fused_fwd.py

========================================================================================= test session starts =========================================================================================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /root/miniconda3/bin/python
cachedir: .pytest_cache
rootdir: /root/autodl-tmp/cuLA
configfile: pyproject.toml
plugins: anyio-4.9.0
collected 50 items                                                                                                                                                                                    

tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B1-T63-H1-HV1-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                 [  2%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B1-T63-H1-HV1-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                 [  4%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T500-H3-HV3-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                [  6%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T500-H3-HV3-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                [  8%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T1000-H3-HV3-D128-gln1-mask_p0.5-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                             [ 10%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T1000-H3-HV3-D128-gln1-mask_p0.5-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                             [ 12%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B3-T1024-H4-HV4-D128-gln0.1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                             [ 14%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B3-T1024-H4-HV4-D128-gln0.1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                             [ 16%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B4-T1024-H4-HV4-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                               [ 18%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B4-T1024-H4-HV4-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                               [ 20%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B4-T1024-H4-HV4-D128-gln1-mask_p0-l2normTrue-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                [ 22%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B4-T1024-H4-HV4-D128-gln1-mask_p0-l2normTrue-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                [ 24%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T1500-H4-HV4-D128-gln10-mask_p0-l2normFalse-gateTrue-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                               [ 26%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T1500-H4-HV4-D128-gln10-mask_p0-l2normFalse-gateTrue-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                               [ 28%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B4-T2048-H8-HV8-D128-gln1-mask_p0-l2normFalse-gateTrue-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                [ 30%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B4-T2048-H8-HV8-D128-gln1-mask_p0-l2normFalse-gateTrue-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                [ 32%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T512-H2-HV4-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                [ 34%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T512-H2-HV4-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                [ 36%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T1024-H2-HV8-D128-gln1-mask_p0-l2normFalse-gateTrue-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                [ 38%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B2-T1024-H2-HV8-D128-gln1-mask_p0-l2normFalse-gateTrue-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                [ 40%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B1-T64-H1-HV2-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_fp32] PASSED                                 [ 42%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B1-T64-H1-HV2-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initTrue-torch.bfloat16-beta_bf16] PASSED                                 [ 44%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B1-T65-H1-HV4-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initFalse-torch.bfloat16-beta_fp32] PASSED                                [ 46%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk[B1-T65-H1-HV4-D128-gln1-mask_p0-l2normFalse-gateFalse-safe_gateTrue-initFalse-torch.bfloat16-beta_bf16] PASSED                                [ 48%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0.1-cu_seqlens[0, 15]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                                        [ 50%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0.1-cu_seqlens[0, 15]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                                        [ 52%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0.9-cu_seqlens[0, 256, 500, 1000]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                            [ 54%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0.9-cu_seqlens[0, 256, 500, 1000]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                            [ 56%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0.5-cu_seqlens[0, 256, 500, 1000]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                            [ 58%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0.5-cu_seqlens[0, 256, 500, 1000]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                            [ 60%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0-cu_seqlens[0, 15, 100, 300, 1200, 2000]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                    [ 62%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0-cu_seqlens[0, 15, 100, 300, 1200, 2000]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                    [ 64%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0-cu_seqlens[0, 100, 300, 1200, 3000, 4096]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                  [ 66%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H4-HV4-D128-mask_p0-cu_seqlens[0, 100, 300, 1200, 3000, 4096]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                  [ 68%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H2-HV4-D128-mask_p0-cu_seqlens[0, 63, 130]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                                     [ 70%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H2-HV4-D128-mask_p0-cu_seqlens[0, 63, 130]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                                     [ 72%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H1-HV2-D128-mask_p0-cu_seqlens[0, 1]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                                           [ 74%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H1-HV2-D128-mask_p0-cu_seqlens[0, 1]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                                           [ 76%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H1-HV2-D128-mask_p0-cu_seqlens[0, 63, 64, 65]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED                                  [ 78%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H1-HV2-D128-mask_p0-cu_seqlens[0, 63, 64, 65]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED                                  [ 80%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H2-HV4-D128-mask_p0-cu_seqlens[0, 17, 64, 65, 130]-torch.bfloat16-safe_gateTrue-initFalse-beta_fp32] PASSED                            [ 82%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H2-HV4-D128-mask_p0-cu_seqlens[0, 17, 64, 65, 130]-torch.bfloat16-safe_gateTrue-initFalse-beta_bf16] PASSED                            [ 84%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED [ 86%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED [ 88%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED [ 90%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED [ 92%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED [ 94%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED [ 96%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_fp32] PASSED [ 98%]
tests/test_kda_fused_fwd.py::test_safe_gate_chunk_varlen[H32-HV32-D128-mask_p0-cu_seqlens[0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192]-torch.bfloat16-safe_gateTrue-initTrue-beta_bf16] PASSED [100%]

========================================================================================= 50 passed in 39.79s =========================================================================================

Performance

python benchmarks/bench_kda_fused_fwd.py
[Device] NVIDIA H20  compute capability sm90  →  using cula.kda.hopper_fused_fwd.cula_kda_prefill

====================================================================================================
 Fixed-Length Benchmark: cuLA fully-fused (sm90) vs FLA Triton  (GVA when HV > H)
====================================================================================================
      [MHA]  B=1 T=1024 H=64 HV=64
      [MHA]  B=1 T=4096 H=64 HV=64
      [MHA]  B=1 T=8192 H=64 HV=64
      [MHA]  B=1 T=16384 H=64 HV=64
      [MHA]  B=2 T=4096 H=64 HV=64
      [MHA]  B=2 T=8192 H=64 HV=64
   [GVA 4x]  B=1 T=1024 H=16 HV=64
   [GVA 4x]  B=1 T=4096 H=16 HV=64
   [GVA 4x]  B=1 T=8192 H=16 HV=64
   [GVA 4x]  B=1 T=16384 H=16 HV=64
   [GVA 2x]  B=1 T=4096 H=32 HV=64
   [GVA 2x]  B=1 T=8192 H=32 HV=64
   [GVA 8x]  B=1 T=4096 H=8 HV=64
   [GVA 8x]  B=1 T=8192 H=8 HV=64
   [GVA 4x]  B=2 T=4096 H=16 HV=64
   [GVA 4x]  B=2 T=8192 H=16 HV=64

====================================================================================================
 Varlen Benchmark: cuLA fully-fused (sm90) vs FLA Triton  (GVA when HV > H)
====================================================================================================
      [MHA]  uniform 10seqs T=4096 H=64 HV=64
      [MHA]  random 10seqs T=4096 H=64 HV=64
      [MHA]  skewed 10seqs T=4096 H=64 HV=64
      [MHA]  uniform 20seqs T=4096 H=64 HV=64
      [MHA]  random 20seqs T=4096 H=64 HV=64
      [MHA]  skewed 20seqs T=4096 H=64 HV=64
      [MHA]  uniform 10seqs T=8192 H=64 HV=64
      [MHA]  random 10seqs T=8192 H=64 HV=64
      [MHA]  skewed 10seqs T=8192 H=64 HV=64
      [MHA]  uniform 20seqs T=8192 H=64 HV=64
      [MHA]  random 20seqs T=8192 H=64 HV=64
      [MHA]  skewed 20seqs T=8192 H=64 HV=64
   [GVA 4x]  uniform 10seqs T=4096 H=16 HV=64
   [GVA 4x]  random 10seqs T=4096 H=16 HV=64
   [GVA 4x]  skewed 10seqs T=4096 H=16 HV=64
   [GVA 4x]  uniform 20seqs T=4096 H=16 HV=64
   [GVA 4x]  random 20seqs T=4096 H=16 HV=64
   [GVA 4x]  skewed 20seqs T=4096 H=16 HV=64
   [GVA 4x]  uniform 10seqs T=8192 H=16 HV=64
   [GVA 4x]  random 10seqs T=8192 H=16 HV=64
   [GVA 4x]  skewed 10seqs T=8192 H=16 HV=64
   [GVA 4x]  uniform 20seqs T=8192 H=16 HV=64
   [GVA 4x]  random 20seqs T=8192 H=16 HV=64
   [GVA 4x]  skewed 20seqs T=8192 H=16 HV=64
   [GVA 2x]  uniform 10seqs T=4096 H=32 HV=64
   [GVA 2x]  random 10seqs T=4096 H=32 HV=64
   [GVA 2x]  skewed 10seqs T=4096 H=32 HV=64
   [GVA 2x]  uniform 20seqs T=4096 H=32 HV=64
   [GVA 2x]  random 20seqs T=4096 H=32 HV=64
   [GVA 2x]  skewed 20seqs T=4096 H=32 HV=64
   [GVA 2x]  uniform 10seqs T=8192 H=32 HV=64
   [GVA 2x]  random 10seqs T=8192 H=32 HV=64
   [GVA 2x]  skewed 10seqs T=8192 H=32 HV=64
   [GVA 2x]  uniform 20seqs T=8192 H=32 HV=64
   [GVA 2x]  random 20seqs T=8192 H=32 HV=64
   [GVA 2x]  skewed 20seqs T=8192 H=32 HV=64


========================================================================================================================
                  BENCHMARK REPORT: cula_kda_fused_fwd (fully-fused)
                  cuLA sm90 fully-fused vs FLA Triton
                  D=128  dtype=bf16  safe_gate=True  has_init_state=False
                  GVA rows are those with HV > H (H, HV shown per row).
                  Warmup=25  Iters=100
========================================================================================================================

  [Fixed-Length]
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────
    B       T    H   HV   GVA  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────
    1    1024   64   64    no  │    0.000023    0.008721    0.000011  │     0.4526      0.2663     1.70x
    1    4096   64   64    no  │    0.000021    0.010050    0.000009  │     1.4265      0.9424     1.51x
    1    8192   64   64    no  │    0.000019    0.007853    0.000008  │     2.7906      1.8515     1.51x
    1   16384   64   64    no  │    0.000020    0.009662    0.000009  │     5.5181      3.7032     1.49x
    2    4096   64   64    no  │    0.000019    0.007853    0.000008  │     2.7824      1.8680     1.49x
    2    8192   64   64    no  │    0.000020    0.009662    0.000009  │     5.5173      3.6955     1.49x
    1    1024   16   64    4x  │    0.000022    0.010582    0.000010  │     0.4384      0.2652     1.65x
    1    4096   16   64    4x  │    0.000020    0.010811    0.000009  │     1.4257      0.9414     1.51x
    1    8192   16   64    4x  │    0.000021    0.009434    0.000010  │     2.7798      1.8388     1.51x
    1   16384   16   64    4x  │    0.000021    0.008584    0.000009  │     5.5074      3.6815     1.50x
    1    4096   32   64    2x  │    0.000021    0.008108    0.000009  │     1.4173      0.9363     1.51x
    1    8192   32   64    2x  │    0.000022    0.010050    0.000010  │     2.7764      1.8313     1.52x
    1    4096    8   64    8x  │    0.000019    0.008380    0.000008  │     1.4214      0.9347     1.52x
    1    8192    8   64    8x  │    0.000020    0.009950    0.000009  │     2.7803      1.8381     1.51x
    2    4096   16   64    4x  │    0.000021    0.009434    0.000010  │     2.7750      1.8550     1.50x
    2    8192   16   64    4x  │    0.000021    0.008584    0.000009  │     5.4833      3.6633     1.50x
  ──────────────────────────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config    H   HV   GVA  │        RMSE     rel_max   mean_diff  │    FLA(ms)    cuLA(ms)   Speedup
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409   64   64    no  │    0.000020    0.009950    0.000009  │     1.5009      1.1548     1.30x
        random 10seqs T=4096 [24..1201] avg=409   64   64    no  │    0.000020    0.007538    0.000009  │     1.4825      1.0289     1.44x
       skewed 10seqs T=4096 [227..2053] avg=409   64   64    no  │    0.000020    0.010050    0.000009  │     1.4843      1.0091     1.47x
       uniform 20seqs T=4096 [204..220] avg=204   64   64    no  │    0.000020    0.007538    0.000009  │     1.6396      1.3260     1.24x
          random 20seqs T=4096 [5..787] avg=204   64   64    no  │    0.000020    0.007538    0.000009  │     1.5856      1.1325     1.40x
       skewed 20seqs T=4096 [107..2063] avg=204   64   64    no  │    0.000020    0.010050    0.000009  │     1.5306      1.1067     1.38x
       uniform 10seqs T=8192 [819..821] avg=819   64   64    no  │    0.000018    0.010471    0.000008  │     2.7821      1.9379     1.44x
        random 10seqs T=8192 [48..2401] avg=819   64   64    no  │    0.000018    0.007853    0.000008  │     2.8232      1.8791     1.50x
       skewed 10seqs T=8192 [455..4097] avg=819   64   64    no  │    0.000018    0.010417    0.000008  │     2.8304      1.8865     1.50x
       uniform 20seqs T=8192 [409..421] avg=409   64   64    no  │    0.000018    0.010471    0.000008  │     2.9118      2.1975     1.33x
         random 20seqs T=8192 [9..1574] avg=409   64   64    no  │    0.000018    0.007853    0.000008  │     2.9067      1.9644     1.48x
       skewed 20seqs T=8192 [215..4107] avg=409   64   64    no  │    0.000018    0.007772    0.000008  │     2.9206      2.0287     1.44x
       uniform 10seqs T=4096 [409..415] avg=409   16   64    4x  │    0.000020    0.010811    0.000009  │     1.5033      1.1589     1.30x
        random 10seqs T=4096 [24..1201] avg=409   16   64    4x  │    0.000019    0.010811    0.000009  │     1.4815      1.0303     1.44x
       skewed 10seqs T=4096 [227..2053] avg=409   16   64    4x  │    0.000020    0.010811    0.000009  │     1.4863      1.0107     1.47x
       uniform 20seqs T=4096 [204..220] avg=204   16   64    4x  │    0.000019    0.008108    0.000008  │     1.6396      1.3246     1.24x
          random 20seqs T=4096 [5..787] avg=204   16   64    4x  │    0.000019    0.008108    0.000008  │     1.5864      1.1339     1.40x
       skewed 20seqs T=4096 [107..2063] avg=204   16   64    4x  │    0.000019    0.010811    0.000008  │     1.5314      1.1107     1.38x
       uniform 10seqs T=8192 [819..821] avg=819   16   64    4x  │    0.000021    0.009390    0.000010  │     2.7835      1.9372     1.44x
        random 10seqs T=8192 [48..2401] avg=819   16   64    4x  │    0.000021    0.009434    0.000010  │     2.8234      1.8795     1.50x
       skewed 10seqs T=8192 [455..4097] avg=819   16   64    4x  │    0.000021    0.009434    0.000010  │     2.8294      1.8836     1.50x
       uniform 20seqs T=8192 [409..421] avg=409   16   64    4x  │    0.000021    0.007075    0.000009  │     2.9094      2.1946     1.33x
         random 20seqs T=8192 [9..1574] avg=409   16   64    4x  │    0.000021    0.009479    0.000009  │     2.9052      1.9608     1.48x
       skewed 20seqs T=8192 [215..4107] avg=409   16   64    4x  │    0.000021    0.009434    0.000009  │     2.9204      2.0305     1.44x
       uniform 10seqs T=4096 [409..415] avg=409   32   64    2x  │    0.000021    0.008108    0.000009  │     1.5014      1.1581     1.30x
        random 10seqs T=4096 [24..1201] avg=409   32   64    2x  │    0.000021    0.010811    0.000009  │     1.4841      1.0312     1.44x
       skewed 10seqs T=4096 [227..2053] avg=409   32   64    2x  │    0.000021    0.008108    0.000009  │     1.4842      1.0131     1.46x
       uniform 20seqs T=4096 [204..220] avg=204   32   64    2x  │    0.000020    0.010811    0.000009  │     1.6384      1.3225     1.24x
          random 20seqs T=4096 [5..787] avg=204   32   64    2x  │    0.000020    0.008108    0.000009  │     1.5854      1.1337     1.40x
       skewed 20seqs T=4096 [107..2063] avg=204   32   64    2x  │    0.000020    0.008108    0.000009  │     1.5341      1.1088     1.38x
       uniform 10seqs T=8192 [819..821] avg=819   32   64    2x  │    0.000021    0.010050    0.000010  │     2.7843      1.9374     1.44x
        random 10seqs T=8192 [48..2401] avg=819   32   64    2x  │    0.000021    0.010050    0.000010  │     2.8205      1.8759     1.50x
       skewed 10seqs T=8192 [455..4097] avg=819   32   64    2x  │    0.000021    0.010050    0.000010  │     2.8302      1.8818     1.50x
       uniform 20seqs T=8192 [409..421] avg=409   32   64    2x  │    0.000021    0.009950    0.000009  │     2.9101      2.1988     1.32x
         random 20seqs T=8192 [9..1574] avg=409   32   64    2x  │    0.000021    0.010050    0.000010  │     2.9021      1.9671     1.48x
       skewed 20seqs T=8192 [215..4107] avg=409   32   64    2x  │    0.000021    0.010050    0.000009  │     2.9166      2.0099     1.45x
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

========================================================================================================================

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces Grouped Value Attention (GVA) support to the KDA forward prefill kernel for Hopper architectures. The changes involve decoupling Q/K head counts from V/O head counts, updating the tile scheduler to share Q/K within GVA groups, and adjusting TMA load/store logic. The Python interface was also updated with enhanced validation and a fix for the output_final_state return logic. Review feedback identifies a potential optimization to avoid unnecessary memory allocation and GPU bandwidth usage by conditionally skipping output_state initialization when the final state is not needed.

Comment thread csrc/api/kda_sm90.cu Outdated
Comment on lines 67 to 71
  torch::Tensor output_state = output_state_.has_value()
      ? output_state_.value()
      : torch::zeros(
-           {num_seqs, num_heads, head_size, head_size},
+           {num_seqs, num_v_heads, head_size, head_size},
            torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));

Severity: medium

The output_state tensor is always allocated and zero-initialized even when the caller does not require the final state (i.e., when output_final_state=False in the Python API). For large models or long sequences, this results in significant unnecessary memory allocation and GPU bandwidth consumption during the kernel's write-back phase. Consider passing a flag to the kernel to skip the state store, or at least avoid the allocation in the C++ API if the result is not needed by the caller.
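
A rough sketch of what the suggested change could look like. This is not the PR's code: the helper, its signature, and the output_final_state flag (borrowed from the Python API) are all hypothetical here.

#include <torch/torch.h>
#include <optional>

// Only materialize the state tensor when the caller asked for the final state
// or supplied a buffer; otherwise return an undefined tensor and let the
// kernel skip the state store entirely.
torch::Tensor make_output_state(const std::optional<torch::Tensor>& output_state_,
                                bool output_final_state,
                                int64_t num_seqs, int64_t num_v_heads,
                                int64_t head_size, const torch::Device& device) {
  if (output_state_.has_value()) {
    return output_state_.value();
  }
  if (output_final_state) {
    return torch::zeros(
        {num_seqs, num_v_heads, head_size, head_size},
        torch::TensorOptions().dtype(torch::kFloat32).device(device));
  }
  return torch::Tensor();  // undefined; the kernel would check .defined() before storing
}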

sunnyxyli added 2 commits May 6, 2026 11:30
@KevinZeng08 (Collaborator)

Thanks for your contribution. Could you attach the performance results with this PR?

Comment thread tests/test_kda_fused_fwd.py Outdated
Comment thread tests/test_kda_fused_fwd.py Outdated
sjmshsh (Author) commented May 6, 2026

> Thanks for your contribution. Could you attach the performance results with this PR?

Okay, please wait a moment.

@KevinZeng08 (Collaborator)

It seems that the benchmark does not contain a GVA setting. Could you add the GVA data preparation in benchmark.utils and add the GVA benchmark?

sjmshsh (Author) commented May 7, 2026

> It seems that the benchmark does not contain a GVA setting. Could you add the GVA data preparation in benchmark.utils and add the GVA benchmark?

👌

Comment thread csrc/api/kda_sm90.cu Outdated
Comment thread benchmarks/bench_kda_fused_fwd.py Outdated
Comment thread benchmarks/bench_kda_fused_fwd.py Outdated
Comment thread benchmarks/bench_kda_fused_fwd.py Outdated
@KevinZeng08 (Collaborator)

Hi, this PR #66 removes the redundant zero init for output_final_state=False, you can merge it in your kda_sm90 interface, thanks~

sjmshsh (Author) commented May 9, 2026

> Hi, this PR #66 removes the redundant zero init for output_final_state=False, you can merge it in your kda_sm90 interface, thanks~

OK

sunnyxyli added 3 commits May 9, 2026 11:52
Comment thread benchmarks/utils.py
the effective head counts uniformly.
"""
HV = H if num_v_heads is None else num_v_heads
assert H > 0, f"H must be positive, got {H}."
Collaborator

These two asserts, H > 0 and HV > 0, can be deleted.

Comment on lines +429 to +431
"GVA rows (HV > H) are mixed in alongside MHA rows (HV == H) "
"under the same sequence-length settings, so GVA and MHA can be "
"compared side by side."
Collaborator

this comment can be deleted

Comment on lines +479 to +489
# GVA (HV > H) at the same (B, T) shapes for side-by-side comparison:
(1, 1024, 16, 64), # 4x
(1, 4096, 16, 64), # 4x
(1, 8192, 16, 64), # 4x
(1, 16384, 16, 64), # 4x
(1, 4096, 32, 64), # 2x
(1, 8192, 32, 64), # 2x
(1, 4096, 8, 64), # 8x
(1, 8192, 8, 64), # 8x
(2, 4096, 16, 64), # 4x
(2, 8192, 16, 64), # 4x
Collaborator

The HV parameter can already be specified by the user with --hv.
Since we have that HV parameter, these extra test settings are no longer needed; just restore this block so it is not modified.

Comment on lines +492 to +504
# Varlen configs — identical sequence-length layouts replayed with and
# without GVA so MHA and GVA can be compared row-by-row.
varlen_configs_base = build_varlen_configs(
    num_seqs_list=(10, 20),
-   total_lens=(4096, 8192, 16384),
+   total_lens=(4096, 8192),
    dists=("uniform", "random", "skewed"),
)
gva_varlen_mixed = [
    (seq_lens, T, dist, H_qk, HV)
    for (H_qk, HV) in ((16, 64), (32, 64))
    for (seq_lens, T, dist) in varlen_configs_base
]
varlen_configs = list(varlen_configs_base) + gva_varlen_mixed
Collaborator

Same as above

Comment on lines +162 to +174
# Config normalization helpers
# ============================================================
def _normalize_fixed_config(cfg):
    """Accept either (B, T) or (B, T, H_qk, HV) and return the 4-tuple form.

    For the 2-tuple legacy form, defaults to H_qk=HV=H (no GVA).
    """
    if len(cfg) == 2:
        B, T = cfg
        return B, T, H, H
    if len(cfg) == 4:
        return cfg
    raise ValueError(f"Fixed config must be (B, T) or (B, T, H, HV), got {cfg!r}")
Collaborator

these normalization helpers are no longer needed as well

Comment thread benchmarks/utils.py
Comment on lines +355 to +356
H=H,
HV=HV,
Collaborator

no need to return
